# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import List, Optional, Tuple

import torch

from megatron.core.tensor_parallel import (
    all_to_all,
    gather_from_sequence_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.moe.moe_utils import (
    get_capacity,
    permute,
    sort_chunks_by_idxs,
    unpermute,
)
from megatron.core.transformer.transformer_config import TransformerConfig


#HACK(huang.huang):use pin_memory, non_blocking, cuda.stream to optimze D2H
#class MoEAlltoAllTokenDispatcher:
#   function update: __init__, preprocess, token_permutation, token_unpermutation
#       __init__: extra opration after ori_init
#           a. copy sort_input_by_local_experts and restore_output_by_local_experts to cpu, prevent redunt D2H in each Forward
#           b. add self.cuda_dtoh_stream to support async copy
#           c. follow Megatron:main, add cuda_sync_point_priority and cuda_dtoh_point to manage sync point
#       preprocess:
#           a. remove D2H such as .cpu() and do it in token_permutation as once
#           b. use new method to manage sync point
#           c. don't set self.cuda_sync_point to "before_permutation_1" when not used drop and pad, while it's seems to be unnecessary
#       token_permutation:
#           a. do ravel for num_global_tokens_per_local_expert and num_global_tokens_per_local_expert_T before D2H
#           b. pass numpy to sort_chunks instead call tolist for gpu tensor
#       token_unpermutation:
#           a. pass numpy to sort_chunks instead call tolist for gpu tensor
#
#   new function: _maybe_update_cuda_sync_point, _ensure_cpu_tensor, _maybe_dtoh_and_synchronize
#       _maybe_update_cuda_sync_point: update sync_point if this point is reached earlier than current one
#       _ensure_cpu_tensor: 
#           a. use pin_memory to alloc cpu tensor, and copy value in gpu tensor to it
#           b. will be used under cuda_dtoh_stream to prevent block
#       _maybe_dtoh_and_synchronize:
#           a. call _ensure_cpu_tensor in dtoh_point, which equals to "before_pemutation1" expert_capacity used
#           b. in self.cuda_sync_point, do cuda_dtoh_stream.synchronize(), .numpy, tolist where synchronize needs to be done before
#           .numpy and tolist, otherwise D2H maybe not be finished. 
def MoEAlltoAllTokenDispatcher___init__(
    self, num_local_experts: int, local_expert_indices: List[int], config: TransformerConfig
) -> None:
    self._orig___init__(num_local_experts, local_expert_indices, config)
    
    self.sort_input_by_local_experts = self.sort_input_by_local_experts.cpu().numpy()
    self.restore_output_by_local_experts = self.restore_output_by_local_experts.cpu().numpy()
    self.cuda_sync_point_priority = {
        "before_permutation_1": 0,
        "before_ep_alltoall": 1,
        "before_permutation_2": 2,
        "before_finish": 3,
        "no_sync": 4,
    }
    self.cuda_dtoh_point = "before_permutation_1"
    self.cuda_dtoh_stream = torch.cuda.Stream()

def MoEAlltoAllTokenDispatcher__maybe_update_cuda_sync_point(self, point: str):
    """
    Update the CUDA sync point if the priority of the new point is higher than the current
    sync point, which means the new point is reached earlier than the current sync point.
    """
    if (
        self.cuda_sync_point_priority[point]
        < self.cuda_sync_point_priority[self.cuda_sync_point]
    ):
        self.cuda_sync_point = point

def MoEAlltoAllTokenDispatcher__ensure_cpu_tensor(self, cpu_attr_name, gpu_tensor):
    if gpu_tensor is None:
        return
    cpu_tensor = getattr(self, cpu_attr_name, None)
    if cpu_tensor is None:
        cpu_tensor = torch.empty(
            gpu_tensor.size(),
            device="cpu",
            pin_memory=True,
            dtype=gpu_tensor.dtype
        )
        setattr(self, cpu_attr_name, cpu_tensor)
    cpu_tensor.copy_(gpu_tensor, non_blocking=True)
    gpu_tensor.record_stream(torch.cuda.current_stream())

def MoEAlltoAllTokenDispatcher__maybe_dtoh_and_synchronize(
    self, point: str, tokens_per_expert: torch.Tensor = None,
    num_global_tokens_per_local_expert: torch.Tensor = None,
    num_global_tokens_per_local_expert_T: torch.Tensor = None,
):
    """
    Move all possible GPU tensors to CPU and make a synchronization at the expected point.
    """
    if not self.drop_and_pad:
        if point == self.cuda_dtoh_point:
            # Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
            on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream
            if on_side_stream:
                self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.cuda_dtoh_stream):
                # TODO: use MemcpyBatchAsync instead.
                self._ensure_cpu_tensor('tokens_per_expert_cpu', tokens_per_expert)
                if self.ep_size > 1 or self.tp_size > 1:
                    self._ensure_cpu_tensor('output_splits_tp_cpu', self.output_splits_tp)
                    self._ensure_cpu_tensor('input_splits_cpu', self.input_splits)
                    self._ensure_cpu_tensor('output_splits_cpu', self.output_splits)
                #NOTE(huang.huang): only fuse-moe-permute need num_out_tokens, but in that case fused-kernel will do D2H itself,
                # if we want to do this D2H here, then we need to sync stream before permute1, which resulted in D2H not being well overlaped
                # self._ensure_cpu_tensor('num_out_tokens_cpu', self.num_out_tokens)
                if self.num_local_experts > 1:
                    self._ensure_cpu_tensor('num_global_tokens_per_local_expert_cpu', num_global_tokens_per_local_expert)
                    self._ensure_cpu_tensor('num_global_tokens_per_local_expert_T_cpu', num_global_tokens_per_local_expert_T)
        
        if point == self.cuda_sync_point:
            # Synchronize with the dtoh stream at self.cuda_sync_point.
            self.cuda_dtoh_stream.synchronize()
            # Need to do before sync, otherwise copy for value in num_global_tokens_per_local_expert_cpu is not finished
            # self.num_out_tokens = self.num_out_tokens_cpu.numpy()
            
            tokens_per_expert = self.tokens_per_expert_cpu.numpy().copy()
            # need copy(), because recompute groupedLinear1 save it as msplit, value in ctx will change while next copy gpu data to _cpu
            # not use tolist, since expert will call tolist again, otherwise we need to modify experts.py
            
            if self.ep_size > 1 or self.tp_size > 1:
                self.output_splits_tp = self.output_splits_tp_cpu.numpy().tolist()
                self.input_splits = self.input_splits_cpu.numpy().tolist()
                self.output_splits = self.output_splits_cpu.numpy().tolist()
            if self.num_local_experts > 1:
                self.num_global_tokens_per_local_expert = self.num_global_tokens_per_local_expert_cpu.numpy()
                self.num_global_tokens_per_local_expert_T = self.num_global_tokens_per_local_expert_T_cpu.numpy()

    return tokens_per_expert

def MoEAlltoAllTokenDispatcher_preprocess(self, routing_map: torch.Tensor) -> torch.Tensor:
    """
    Preprocess token routing map for AlltoAll communication and token permutation.

    This method computes the number of tokens assigned to each expert based on the routing_map.
    It also initializes the necessary data structures for AlltoAll communication, such as input
    and output splits, and the mapping between global tokens and local experts.

    Args:
        routing_map (torch.Tensor): The mapping of tokens to experts, with shape
            [num_tokens, num_experts].

    Returns:
        torch.Tensor: Tensor containing the number of tokens assigned to local expert.
    """
    # [num_experts], number of tokens assigned to each expert from the current rank's input.
    num_local_tokens_per_expert = routing_map.sum(dim=0).long()

    if self.drop_and_pad:
        # Drop and pad the input to capacity.
        num_tokens = routing_map.size(0) * self.config.moe_router_topk
        self.capacity = get_capacity(
            num_tokens=num_tokens,
            num_experts=self.num_experts,
            capacity_factor=self.moe_expert_capacity_factor,
        )
        self.num_out_tokens = self.capacity * self.num_experts
        # [num_local_experts], number of tokens processed by each expert.
        num_tokens_per_local_expert = torch.full(
            (self.num_local_experts,),
            self.capacity * self.tp_size * self.ep_size,
            dtype=torch.long,
        )
        # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
        # to each local expert by all ranks.
        self.num_global_tokens_per_local_expert = torch.full(
            (self.num_experts * self.tp_size,),
            self.capacity,
            dtype=torch.long,
            device=self.permute_idx_device,
        )
        return num_tokens_per_local_expert
    elif self.config.moe_expert_capacity_factor is not None:
        # Drop tokens to capacity, no padding.
        # A synchronization is needed before the first
        # permutation to get the `num_out_tokens` CPU value.
        self.num_out_tokens = num_local_tokens_per_expert.sum()
        # #TODO(huang.huang): make sure num_out_tokens is not needed for permutation_1 excpet drop_and_pad
        # self.cuda_sync_point = "before_permutation_1" 

    else:
        # Dropless
        self.num_out_tokens = routing_map.size(0) * self.config.moe_router_topk
        #HACK(huang.huang): move setattr for self.cuda_sync_point below
        # if self.ep_size > 1 or self.num_local_experts > 1:
        #     # Token dropless and enable ep. A synchronization is needed before expert parallel
        #     # AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
        #     self.cuda_sync_point = "before_ep_alltoall"
        # else:
        #     # Token dropless and no ep. A synchronization is needed before the returns
        #     # to get the `tokens_per_expert` CPU value for
        #     self.cuda_sync_point = "before_finish"
        ##HACK(huang.huang)
    if self.ep_size > 1 or self.tp_size > 1:
        # ===================================================
        # Calculate input_splits, output_splits for alltoall/allgather in variable size.
        # ===================================================
        # [ep_size]. Represents the number of tokens sent by the current rank to other
        # EP ranks.
        self.input_splits = num_local_tokens_per_expert.reshape(
            self.ep_size, self.num_local_experts
        ).sum(axis=1)
        # Gather the global distribution of tokens across ranks.
        # num_global_tokens_per_expert represents the number of tokens sent to each
        # expert by all ranks.
        # [tp_size, ep_size, num_experts]
        num_global_tokens_per_expert = (
            gather_from_sequence_parallel_region(
                num_local_tokens_per_expert, group=self.tp_ep_group
            )
            .reshape(self.ep_size, self.tp_size, self.num_experts)
            .transpose(0, 1)
        )
        # [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
        num_global_tokens_per_local_expert = num_global_tokens_per_expert[
            :, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
        ].contiguous()
        # [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]
        num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2)
        # [tp_size, ep_size] -> [ep_size]
        # self.output_splits represents the number of tokens received by the current rank
        # from other EP rank.
        self.output_splits = num_global_tokens_per_rank[self.tp_rank]
        # [tp_size, ep_size] -> [tp_size]
        # self.output_splits_tp represents the number of tokens received by the current
        # rank from other TP rank.
        self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
        # [tp_size, ep_size, num_local_experts] -> [num_local_experts]
        num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1))
        # A synchronization is needed before expert parallel AlltoAll communication
        # to get the `input_splits` and `output_splits` CPU values.
        self._maybe_update_cuda_sync_point("before_ep_alltoall")
    else:
        num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
            self.num_experts
        )
        num_tokens_per_local_expert = num_local_tokens_per_expert
        # A synchronization is needed before the returns
        # to get the `num_tokens_per_local_expert` CPU value.
        self._maybe_update_cuda_sync_point("before_finish")

    if self.num_local_experts > 1:
        # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
        # to each local expert by all ranks.
        self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(
            -1, self.num_local_experts
        )
        if not self.config.moe_permute_fusion:
            # A synchronization is needed before permutation 2
            # to get the `num_global_tokens_per_local_expert` CPU value.
            self._maybe_update_cuda_sync_point("before_permutation_2")

    return num_tokens_per_local_expert


def MoEAlltoAllTokenDispatcher_token_permutation(
    self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dispatch tokens to local experts using AlltoAll communication.

    This method performs the following steps:
    1. Preprocess the routing map to get metadata for communication and permutation.
    2. Permute input tokens for AlltoAll communication.
    3. Perform expert parallel AlltoAll communication.
    4. Sort tokens by local expert (if multiple local experts exist).

    Args:
        hidden_states (torch.Tensor): Input token embeddings.
        probs (torch.Tensor): The probabilities of token to experts assignment.
        routing_map (torch.Tensor): The mapping of token to experts assignment.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - Permuted token embeddings for local experts.
            - Number of tokens per expert.
    """
    # Preprocess: Get the metadata for communication, permutation and computation operations.
    self.hidden_shape = hidden_states.shape
    self.probs = probs
    self.routing_map = routing_map
    assert probs.dim() == 2, "Expected 2D tensor for probs"
    assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
    assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
    hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
    tokens_per_expert = self.preprocess(self.routing_map)

    if self.shared_experts is not None:
        self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))

    # Permutation 1: input to AlltoAll input
    self.hidden_shape_before_permute = hidden_states.shape

    #GPU operations that data depended on need to be performed before the d2h command
    num_global_tokens_per_local_expert = self.num_global_tokens_per_local_expert.ravel()
    num_global_tokens_per_local_expert_T = self.num_global_tokens_per_local_expert.T.ravel()

    # Permutation 1: input to AlltoAll input
    tokens_per_expert = self._maybe_dtoh_and_synchronize(
        "before_permutation_1", tokens_per_expert,
        num_global_tokens_per_local_expert, num_global_tokens_per_local_expert_T
    )
    permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
        hidden_states,
        routing_map,
        num_out_tokens=self.num_out_tokens,
        fused=self.config.moe_permute_fusion,
        drop_and_pad=self.drop_and_pad,
    )
    # Perform expert parallel AlltoAll communication
    tokens_per_expert = self._maybe_dtoh_and_synchronize(
        "before_ep_alltoall", tokens_per_expert,
        num_global_tokens_per_local_expert, num_global_tokens_per_local_expert_T
    )

    global_input_tokens = all_to_all(
        self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
    )
    if self.shared_experts is not None:
        self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)

    if self.tp_size > 1:
        if self.output_splits_tp is None:
            output_split_sizes = None
        else:
            output_split_sizes = self.output_splits_tp.tolist()
        global_input_tokens = gather_from_sequence_parallel_region(
            global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
        )

    # Permutation 2: Sort tokens by local expert.
    tokens_per_expert = self._maybe_dtoh_and_synchronize(
        "before_permutation_2", tokens_per_expert,
        num_global_tokens_per_local_expert, num_global_tokens_per_local_expert_T
    )
    if self.num_local_experts > 1:
        if self.drop_and_pad:
            global_input_tokens = (
                global_input_tokens.view(
                    self.tp_size * self.ep_size,
                    self.num_local_experts,
                    self.capacity,
                    *global_input_tokens.size()[1:],
                )
                .transpose(0, 1)
                .contiguous()
                .flatten(start_dim=0, end_dim=2)
            )
        else:
            global_input_tokens = sort_chunks_by_idxs(
                global_input_tokens,
                self.num_global_tokens_per_local_expert,
                self.sort_input_by_local_experts,
                fused=self.config.moe_permute_fusion,
            )

    tokens_per_expert = self._maybe_dtoh_and_synchronize(
        "before_finish", tokens_per_expert,
        num_global_tokens_per_local_expert, num_global_tokens_per_local_expert_T
    )
    return global_input_tokens, tokens_per_expert


def MoEAlltoAllTokenDispatcher_token_unpermutation(
    self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Reverse the token permutation to restore the original order.

    This method performs the following steps:
    1. Unsort tokens by local expert (if multiple local experts exist).
    2. Perform expert parallel AlltoAll communication to restore the original order.
    3. Unpermute tokens to restore the original order.

    Args:
        hidden_states (torch.Tensor): Output from local experts.
        bias (torch.Tensor, optional): Bias tensor (not supported).

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]:
            - Unpermuted token embeddings in the original order.
            - None (bias is not supported).
    """
    assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"

    # Unpermutation 2: Unsort tokens by local expert.
    if self.num_local_experts > 1:
        if self.drop_and_pad:
            hidden_states = (
                hidden_states.view(
                    self.num_local_experts,
                    self.tp_size * self.ep_size,
                    self.capacity,
                    *hidden_states.size()[1:],
                )
                .transpose(0, 1)
                .contiguous()
                .flatten(start_dim=0, end_dim=2)
            )
        else:
            hidden_states = sort_chunks_by_idxs(
                hidden_states,
                self.num_global_tokens_per_local_expert_T,
                self.restore_output_by_local_experts,
                fused=self.config.moe_permute_fusion,
            )

    if self.tp_size > 1:
        if self.output_splits_tp is None:
            input_split_sizes = None
        else:
            input_split_sizes = self.output_splits_tp.tolist()
        hidden_states = reduce_scatter_to_sequence_parallel_region(
            hidden_states, group=self.tp_group, input_split_sizes=input_split_sizes
        )

    # Perform expert parallel AlltoAll communication
    # hidden_states: [SEQL, H] -> [SEQL, H/TP]
    permutated_local_input_tokens = all_to_all(
        self.ep_group, hidden_states, self.input_splits, self.output_splits
    )
    if self.shared_experts is not None:
        self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
        self.shared_experts.post_forward_comm()

    # Unpermutation 1: AlltoAll output to output
    output = unpermute(
        permutated_local_input_tokens,
        self.reversed_local_input_permutation_mapping,
        restore_shape=self.hidden_shape_before_permute,
        probs=self.probs,
        routing_map=self.routing_map,
        fused=self.config.moe_permute_fusion,
        drop_and_pad=self.drop_and_pad,
    )

    # Reshape the output tensor
    output = output.view(self.hidden_shape)

    # Add shared experts output
    if self.shared_experts is not None:
        shared_expert_output = self.shared_experts.get_output()
        output += shared_expert_output
    return output, None
##HACK(huang.huang)

from transformer_engine.musa.pytorch.utils import replace_attr, add_attr
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
replace_attr(MoEAlltoAllTokenDispatcher, '__init__', MoEAlltoAllTokenDispatcher___init__)
add_attr(MoEAlltoAllTokenDispatcher, '_maybe_update_cuda_sync_point', MoEAlltoAllTokenDispatcher__maybe_update_cuda_sync_point)
add_attr(MoEAlltoAllTokenDispatcher, '_ensure_cpu_tensor', MoEAlltoAllTokenDispatcher__ensure_cpu_tensor)
add_attr(MoEAlltoAllTokenDispatcher, '_maybe_dtoh_and_synchronize', MoEAlltoAllTokenDispatcher__maybe_dtoh_and_synchronize)
replace_attr(MoEAlltoAllTokenDispatcher, 'preprocess', MoEAlltoAllTokenDispatcher_preprocess)
replace_attr(MoEAlltoAllTokenDispatcher, 'token_permutation', MoEAlltoAllTokenDispatcher_token_permutation)
replace_attr(MoEAlltoAllTokenDispatcher, 'token_unpermutation', MoEAlltoAllTokenDispatcher_token_unpermutation)